# We want to create a csv file for annotation from all the different systems' outputs.
# First we will accumulate all the outputs of the different systems in a list.
# Then we will randomize/shuffle that list and write them in buckets of 10

import os
import csv
import ast
import random
import string
from sacremoses import MosesTokenizer
# Moses tokenizer used to normalize QuAC questions/responses before dictionary lookup.
mt = MosesTokenizer()

# Fixed seed so the shuffled annotation buckets are reproducible across runs.
random.seed(901)

def print_list(l):
	"""Print every element of *l* on its own line, followed by one blank line."""
	for element in l:
		print(element)
	print()

# Maps a question string -> its gold answer from the dev/test source file;
# filled by read_dev_test_q_and_a() and consulted by every model-output reader.
correct_answers_dict = dict()

def read_beam_outputs(beam_output_file):
	"""Parse an OpenNMT-style beam output file into (id, question, gold_answer, response) tuples.

	"SENT <id>: [...]" lines carry the tokenized source as a Python list
	literal, with question and answer separated by a '|||' token.
	"PRED <id>: ..." lines carry the model response. Score lines are
	neutralized by stripping their markers so they never match "PRED ".
	Results are ordered by sentence id; the gold answer for each question is
	taken from the module-level correct_answers_dict.
	"""
	entries = dict()
	with open(beam_output_file, "r") as fin:
		for raw in fin:
			cleaned = raw.strip()
			# Remove score markers so score lines cannot match "PRED " below.
			cleaned = cleaned.replace("PRED SCORE: ", "").replace("PRED AVG SCORE: ", "")
			if cleaned.startswith("SENT "):
				head, tail = cleaned.split(": [")
				sent_id = int(head.replace("SENT ", ""))
				# Restore the '[' consumed by the split, then parse the list.
				tokens = ast.literal_eval("[" + tail)
				sep = tokens.index('|||')
				entries[sent_id] = (' '.join(tokens[:sep]), ' '.join(tokens[sep + 1:]))
			if cleaned.startswith("PRED "):
				head, prediction = cleaned.split(": ", 1)
				pred_id = int(head.replace("PRED ", ""))
				# Grow (question, answer) into (question, answer, response).
				entries[pred_id] = entries[pred_id] + (prediction,)
	final_qar_list = list()
	for sent_id, (question, _answer, response) in sorted(entries.items()):
		final_qar_list.append((sent_id, question, correct_answers_dict[question], response))
	return final_qar_list

def read_dialogpt_model_outputs(dialogpt_output_file):
	"""Parse a DialoGPT prediction file into (id, question, gold_answer, response) tuples.

	Record layout (records separated by blank lines):
	  line 1: "<id>: <question ||| answer>"
	  line 2: tab-separated fields whose last column is the response
	Ids in the file are 0-based and shifted to 1-based to match the other
	readers. Gold answers come from the module-level correct_answers_dict.
	"""
	# dictionary of id and question and response
	final_qar_list = list()
	# Two-state reader: expect a question line first, then its response line.
	q_flag = True
	r_flag = False
	with open(dialogpt_output_file, "r") as reader:

		for i,line in enumerate(reader):
			if q_flag:
				# read q and id
				q_id, question = line.split(":", 1)
				question = question.strip()
				# print(q_id, question)
				q_flag = False
				r_flag = True
			elif not q_flag and r_flag:
				# if not line.strip():
				# 	print("WTF:", i, q_id, question)
				# 	print(":", line, ":")
				# read the response for this id and save
				line_spl = line.split("\t")
				response = line_spl[-1].strip()
				# Save q and response id in the list
				# Keep only the question text (before the " ||| " answer separator).
				q = question.split(" ||| ")[0].strip()
				# +1 converts the file's 0-based ids to the 1-based ids used elsewhere.
				final_qar_list.append((int(q_id)+1, q, correct_answers_dict[q], response))
				r_flag = False
			elif not line.split():
				# reset q_flag: a blank line ends the record, so the next
				# non-blank line is the next question.
				q_flag = True
	return final_qar_list

def read_bert_model_output(bert_output_file, id_q_a_list):
	"""Parse BERT model output and align it with the full dev/test question list.

	The BERT model generated responses for fewer questions than the test set
	contains, so entries are matched by question text; when a question has no
	BERT response, the answer ``a`` carried in ``id_q_a_list`` is reused as
	the response. Each record in the file is one tab-separated line (question
	in column 3, response in column 5) followed by a blank separator line.

	Prints the number of questions the model covered and returns a list of
	(id, question, gold_answer, response) tuples; gold answers come from the
	module-level correct_answers_dict.
	"""
	# question text -> model response
	q_dict = dict()
	response_line = True
	with open(bert_output_file, "r") as reader:
		for line in reader:
			line = line.strip()
			# Check for the blank separator FIRST: the previous order parsed
			# the line before testing emptiness, so two consecutive blank
			# lines crashed with an IndexError on line_spl[3].
			if not line:
				# reset the line tracker for the next question
				response_line = True
				continue
			if response_line:
				# Extract the question and response from the record line.
				line_spl = line.split("\t")
				question = line_spl[3]
				response = line_spl[5]
				q_dict[question] = response
				response_line = False
	# Now generate the id,q,a,r list; copy a into r when q is absent from q_dict.
	final_qar_list = list()
	found_count = 0
	for id, q, a in id_q_a_list:
		correct_a = correct_answers_dict[q]
		if q in q_dict:
			found_count += 1
			final_qar_list.append((id, q, correct_a, q_dict[q]))
		else:
			# No BERT response for this question; fall back to the supplied answer.
			final_qar_list.append((id, q, correct_a, a))
	print(found_count)
	return final_qar_list


def read_dev_test_q_and_a(src_dev_file):
	"""Load the dev/test source file into a list of (id, question, answer) tuples.

	Each line has the form "question ||| answer"; ids are 1-based line
	numbers. As a side effect, every question's gold answer is recorded in
	the module-level correct_answers_dict.
	"""
	global correct_answers_dict
	q_list = list()
	q_set = set()
	with open(src_dev_file, "r") as fin:
		for line_no, raw in enumerate(fin, start=1):
			question, answer = raw.strip().split(" ||| ")
			q_set.add(question)
			correct_answers_dict[question] = answer
			q_list.append((line_no, question, answer))
	# NOTE: already verified that every question in the file is unique.
	# print("set size:", len(q_set))
	return q_list

def read_decomposable_attn_output(decomposable_attn_model_output_file, id_q_a_list):
	"""Parse decomposable-attention model output and align it with the test set.

	Records are blank-line-separated groups: question line, gold-answer line,
	then ranked response lines (tab-separated; only the top-ranked one is
	kept). The model covers fewer questions than the test set, so for
	uncovered questions the answer ``a`` from ``id_q_a_list`` is reused as
	the response. Returns (id, question, gold_answer, response) tuples with
	gold answers taken from the module-level correct_answers_dict.
	"""
	# Since the number of question generated by bert model is less than the actual test set. 
	# We will keep track of the question for which this model couldn't generate a response
	# question text -> (gold_answer, top_response)
	q_dict = dict()
	with open(decomposable_attn_model_output_file, "r") as reader:
		# line_tracker = position inside the current record:
		# 0=question, 1=gold answer, 2=top response, 3=rest of ranked list.
		line_tracker = 0
		instance_tracker = 1
		question, gold_answer, response = None, None, None
		for line in reader:
			line = line.strip()
			# first line is the question
			if line_tracker == 0:
				question = line
				line_tracker += 1
			# second line is the gold answer
			elif line_tracker == 1:
				gold_answer = line
				line_tracker += 1
			# Third line is the first/top ranked response from the model
			elif line_tracker == 2:
				_, response, _, _ = line.split("\t")
				# add to the dict here
				q_dict[question] = (gold_answer, response)
				line_tracker += 1
			# NOTE(review): this blank-line check runs AFTER the branches above,
			# so a blank line seen while line_tracker is 0 or 1 first writes an
			# empty question/gold answer and then resets — harmless for
			# well-formed files, but worth confirming if the format changes.
			if not line:
				# reset the line tracker for the next question
				line_tracker = 0
				instance_tracker += 1
	# Now we want to generate an id,q,a,r list
	# here we will simply copy the a into r if q is not present in the q_dict
	final_qar_list = list()
	found_count = 0
	for id, q, a in id_q_a_list:
		if q in q_dict:
			found_count += 1
			correct_a = correct_answers_dict[q]
			final_qar_list.append((id, q, correct_a, q_dict[q][1]))
		else:
			correct_a = correct_answers_dict[q]
			final_qar_list.append((id, q, correct_a, a))
	return final_qar_list

def read_coqa_output(coqa_output_file):
	"""Parse CoQA prediction output into (id, question, gold_answer, response) tuples.

	For each id the file supplies a "Q <id>\\t\\t:<question>" line, a
	"Gold\\t\\t:<answer>" line and a "Coqa Pred\\t:<answer>" line, in that
	order; the three pieces are accumulated per id by progressively growing
	the dict value from question -> (question, gold) -> (question, gold,
	prediction). Returned gold answers come from correct_answers_dict.
	"""
	# dictionary of id and question and response
	qr_dict = dict()

	with open(coqa_output_file, "r") as reader:
		current_id = -1
		for line in reader:
			line = line.strip()
			if line.startswith("Q "):
				current_id_str, question = line.split("\t\t:")
				new_id = int(current_id_str.replace("Q ", ""))
				# print(new_id)
				# print(question)
				if new_id != current_id:
					# print(current_id, new_id)
					current_id = new_id
				# Repeated Q lines for the same id simply overwrite the entry.
				qr_dict[current_id] = question
			if line.startswith("Gold\t\t:"):
				gold_answer = line.replace("Gold\t\t:", "")
				# Grow the entry from question -> (question, gold).
				qr_dict[current_id] = (qr_dict[current_id], gold_answer)
			if line.startswith("Coqa Pred\t:"):
				coqa_prediction_answer = line.replace("Coqa Pred\t:", "")
				# Grow the entry to (question, gold, prediction).
				qr_dict[current_id] = (*qr_dict[current_id], coqa_prediction_answer)
				# print(current_id, qr_dict[current_id])
	# Order entries by id and attach the gold answer for each question.
	sorted_qr = sorted(qr_dict.items(), key=lambda kv: kv[0])
	final_qar_list = list()
	for id, (q, a, r) in sorted_qr:
		correct_a = correct_answers_dict[q]
		final_qar_list.append((id, q, correct_a, r))
	return final_qar_list

def read_opennmt_output(opennmt_output_file):
	"""Parse OpenNMT beam output into (id, question, gold_answer, response) tuples.

	NOTE(review): this function duplicates read_beam_outputs(); consider
	collapsing the two into a single helper.

	"SENT <id>: [...]" lines hold the tokenized source as a Python list
	literal with question and answer separated by a '|||' token; "PRED <id>:"
	lines hold the response. Gold answers come from correct_answers_dict.
	"""
	# dictionary of id and question and response
	qr_dict = dict()

	with open(opennmt_output_file, "r") as reader:
		for line in reader:
			line = line.strip()
			# Strip score markers so score lines never match "PRED " below.
			line = line.replace("PRED SCORE: ", "")
			line = line.replace("PRED AVG SCORE: ", "")
			if line.startswith("SENT "):
				sent_part, src_list_part = line.split(": [")
				id = int(sent_part.replace("SENT ", ""))
				# Restore the '[' consumed by the split, then parse the list literal.
				src_list_part = "[" + src_list_part
				src_list = ast.literal_eval(src_list_part)
				question = ' '.join(src_list[:src_list.index('|||')])
				answer = ' '.join(src_list[src_list.index('|||')+1:])
				qr_dict[id] = (question, answer)
			if line.startswith("PRED "):
				pred_part, response = line.split(": ", 1)
				id = int(pred_part.replace("PRED ", ""))
				# Grow (question, answer) into (question, answer, response).
				qr_dict[id] = (*qr_dict[id], response)
	# Order entries by sentence id and attach the gold answer per question.
	sorted_qr = sorted(qr_dict.items(), key=lambda kv: kv[0])
	final_qar_list = list()
	for id, (q, a, r) in sorted_qr:
		correct_a = correct_answers_dict[q]
		final_qar_list.append((id, q, correct_a, r))
		# print(id, q, ":::", a)
		# print(r)
	return final_qar_list

def read_quac_output(quac_output_file):
	"""Parse QuAC model output into (id, question, gold_answer, response) tuples.

	Records are "question ||| answer" lines, each followed by a response line
	and a blank separator. Questions and responses are Moses-tokenized and
	lowercased so that questions match the keys of correct_answers_dict; one
	trailing '.', ',' or ';' is stripped from the response. Also prints how
	many raw responses contained the expected answer as a substring.
	"""
	final_qar_list = list()
	correct_answer_count = 0
	current_response = ""
	counter = 1
	with open(quac_output_file, "r") as reader:
		next_line_question = True
		for line in reader:
			line = line.strip()
			if line and next_line_question:
				current_q, current_a = line.split(" ||| ")
				next_line_question = False
			elif line and not next_line_question:
				current_response = line
				if current_a in current_response:
					correct_answer_count += 1
				else:
					# print(current_a, " :: ", current_response)
					pass
				# line[2:-2] drops two framing characters on each side of the
				# raw response before tokenizing (presumably quote/bracket
				# framing — TODO confirm against the file format).
				current_response = mt.tokenize(line[2:-2].strip(), return_str=True, escape=False).lower()
				# Remove trailing punctuation if present. Guard against an
				# empty tokenized response, which previously raised an
				# IndexError on current_response[-1].
				if current_response and current_response[-1] in ['.', ',', ';']:
					current_response = current_response[:-1].strip()
				normalized_q = mt.tokenize(current_q.lower().strip(), return_str=True, escape=False)
				try:
					correct_a = correct_answers_dict[normalized_q]
				except KeyError:
					# Fail loudly: dump the known questions so the mismatch can
					# be debugged, then abort the script.
					print_list(correct_answers_dict.keys())
					exit()
				final_qar_list.append((counter, normalized_q, correct_a, current_response))
			else:
				# Blank line ends the record; the next non-blank line is a question.
				next_line_question = True
				counter += 1
	print("Correct quac answers:", correct_answer_count)
	return final_qar_list

def attach_label_to_qar_list(qar_list, label):
	"""Append *label* to every (id, q, a, r) tuple and report model accuracy.

	A question counts as correct when the lowercased answer occurs as a
	substring of the response. Prints "<label> : <correct> / <total>" and
	returns the new list of (id, q, a, r, label) tuples.
	"""
	labeled = [tup + (label,) for tup in qar_list]
	correct_count = sum(1 for _, _, a, r in qar_list if a.lower() in r)
	print(label, ":", correct_count, "/", len(qar_list))
	return labeled

# Folder holding every model's predictions on the SQuAD seq2seq dev test set.
DATA_FOLDER = "data2"
# Source questions/answers ("question ||| answer" per line) for the dev test set.
src_test_file = os.path.join(DATA_FOLDER, "src_squad_seq2seq_dev_moses_test.txt")
# Baseline QA-model prediction files (BERT, CoQA, QuAC).
bert_model_output = os.path.join(DATA_FOLDER, "bert_softmax_predictions_on_squad_dev_test_0_to_822.txt")
coqa_output = os.path.join(DATA_FOLDER, "coqa_predictions_on_squad_seq2seq_dev_moses_test.txt")
quac_output = os.path.join(DATA_FOLDER, "quac_answers_on_squad_dev_test.txt")
# Pointer-generator seq2seq beam outputs (ss vs ss+ data, with/without pretraining).
ss_pgn_output = os.path.join(DATA_FOLDER, "ss_pgn_squad_model_squad_dev_test_beam_output.txt")
ss_pgn_pre_output = os.path.join(DATA_FOLDER, "ss_pgn_pre_squad_model_squad_dev_test_beam_output.txt")
ss_plus_pgn_output = os.path.join(DATA_FOLDER, "ss_plus_pgn_squad_model_squad_dev_test_beam_output.txt")
ss_plus_pgn_pre_output = os.path.join(DATA_FOLDER, "ss_plus_pgn_pre_squad_model_squad_dev_test_beam_output.txt")
# DialoGPT variants (scratch / finetuned, small / opensub-qa, plus an oracle run).
dialogpt_ss_output = os.path.join(DATA_FOLDER, "dialoGPT_ss_scratch_predictions_on_squad_dev_test_with_squad_model_length_normalized.txt")
dialogpt_ss_plus_output = os.path.join(DATA_FOLDER, "dialoGPT_ss_plus_scratch_predictions_on_squad_dev_test_with_squad_model_length_normalized.txt")
dialogpt_ss_finetuned_small_output = os.path.join(DATA_FOLDER, "dialoGPT_ss_finetuned_predictions_on_squad_dev_test_with_squad_model_length_normalized.txt")
dialogpt_ss_plus_finetuned_small_output = os.path.join(DATA_FOLDER, "dialoGPT_ss_plus_finetuned_predictions_on_squad_dev_test_with_squad_model_length_normalized_scores_new_best.txt")
dialogpt_ss_finetuned_opensub_qa_output = os.path.join(DATA_FOLDER, "dialoGPT_ss_finetuned_opensub_qa_predictions_on_squad_dev_test_with_squad_model_length_normalized_scores_new_best.txt")
dialogpt_ss_plus_finetuned_opensub_qa_output = os.path.join(DATA_FOLDER, "dialoGPT_ss_plus_finetuned_opensub_qa_predictions_on_squad_dev_test_with_squad_model_length_normalized_scores_new_best.txt")
dialogpt_ss_plus_finetuned_small_oracle_output = os.path.join(DATA_FOLDER, "dialoGPT_ss_plus_finetuned_predictions_on_squad_dev_test_length_normalized_scores_new_best.txt")

## LABELS
# bert + rules model responses = bert
# coqa model response = c
# quac model response = quac
# ss_pgn model response = ss_pgn
# ss_pgn_pre model response = ss_pgn_pre
# ss_plus_pgn model response = ss+_pgn
# ss_plus_pgn_pre model response = ss+_pgn_pre
# dialogpt_ss model response = gpt_ss
# dialogpt_ss_plus model response = gpt_ss+
# dialogpt_ss_finetuned_small model response = gpt_ss_sm
# dialogpt_ss_plus_finetuned_small model response = gpt_ss+_sm
# dialogpt_ss_finetuned_opensub_qa model response = gpt_ss_oqa
# dialogpt_ss_plus_finetuned_opensub_qa model response = gpt_ss+_oqa
# dialogpt_ss_plus_finetuned_small model response + oracle = gpt_ss+_sm_o

# Load the test-set questions/answers (also fills correct_answers_dict).
id_q_a_list = read_dev_test_q_and_a(src_test_file)

bert_qar_list = attach_label_to_qar_list(read_bert_model_output(bert_model_output, id_q_a_list), "bert")
c_qar_list = attach_label_to_qar_list(read_coqa_output(coqa_output), "c")
quac_qar_list = attach_label_to_qar_list(read_quac_output(quac_output), "quac")
ss_pgn_qar_list = attach_label_to_qar_list(read_beam_outputs(ss_pgn_output), "ss_pgn")
ss_pgn_pre_qar_list = attach_label_to_qar_list(read_beam_outputs(ss_pgn_pre_output), "ss_pgn_pre")
ss_plus_pgn_qar_list = attach_label_to_qar_list(read_beam_outputs(ss_plus_pgn_output), "ss+_pgn")
ss_plus_pgn_pre_qar_list = attach_label_to_qar_list(read_beam_outputs(ss_plus_pgn_pre_output), "ss+_pgn_pre")
gpt_ss_qar_list = attach_label_to_qar_list(read_dialogpt_model_outputs(dialogpt_ss_output), "gpt_ss")
gpt_ss_plus_qar_list = attach_label_to_qar_list(read_dialogpt_model_outputs(dialogpt_ss_plus_output), "gpt_ss+")
# BUG FIX: this line previously read dialogpt_ss_plus_finetuned_small_output,
# so the "gpt_ss_sm" rows duplicated "gpt_ss+_sm" and the
# dialogpt_ss_finetuned_small_output file was never used.
gpt_ss_sm_qar_list = attach_label_to_qar_list(read_dialogpt_model_outputs(dialogpt_ss_finetuned_small_output), "gpt_ss_sm")
gpt_ss_plus_sm_qar_list = attach_label_to_qar_list(read_dialogpt_model_outputs(dialogpt_ss_plus_finetuned_small_output), "gpt_ss+_sm")
gpt_ss_oqa_qar_list = attach_label_to_qar_list(read_dialogpt_model_outputs(dialogpt_ss_finetuned_opensub_qa_output), "gpt_ss_oqa")
gpt_ss_plus_oqa_qar_list = attach_label_to_qar_list(read_dialogpt_model_outputs(dialogpt_ss_plus_finetuned_opensub_qa_output), "gpt_ss+_oqa")
gpt_ss_plus_sm_oracle_qar_list = attach_label_to_qar_list(read_dialogpt_model_outputs(dialogpt_ss_plus_finetuned_small_oracle_output), "gpt_ss+_sm_o")

# The 14 per-model (id, q, a, r, label) lists, all ordered by question id.
all_qars = [bert_qar_list, c_qar_list, quac_qar_list, ss_pgn_qar_list, ss_pgn_pre_qar_list, ss_plus_pgn_qar_list, ss_plus_pgn_pre_qar_list, gpt_ss_qar_list, gpt_ss_plus_qar_list, gpt_ss_sm_qar_list, gpt_ss_plus_sm_qar_list, gpt_ss_oqa_qar_list, gpt_ss_plus_oqa_qar_list, gpt_ss_plus_sm_oracle_qar_list]

# Group responses based on q,a
# Merge identical responses across models; a merged response carries the
# ":"-joined labels of every model that produced it.
total_unique_response = 0
# question -> (id, gold_answer, {response: "label1:label2:..."})
all_qar_dict = dict()
for id, q, a in id_q_a_list:
	# For each question we will aggregate same responses from different models into one instance. While appending the labels
	# Gather responses and labels from all models for current q,a
	current_responses_and_labels = dict()
	for qar_list in all_qars:
		# Every per-model list is ordered by id, so index id-1 must hold the
		# entry for this question; anything else is a fatal alignment error.
		qar_id, qar_q, qar_a, qar_response, qar_label = qar_list[id-1]
		if id != qar_id:
			print(qar_label)
			print(qar_id)
			print(id)
			print("Serious error in gathering!")
			exit()
		if qar_response in current_responses_and_labels:
			current_responses_and_labels[qar_response] += ":" + qar_label
		else:
			current_responses_and_labels[qar_response] = qar_label
	total_unique_response += len(current_responses_and_labels)
	all_qar_dict[q] = (id, a, current_responses_and_labels)
	# print(id, len(current_responses_and_labels))
# NOTE(review): 15000 looks like a reference figure to compare the unique
# response count against — confirm how it was derived.
print(total_unique_response, 15000)
print(len(all_qar_dict))

def convert_dict_to_array(qars, n_start=400, n_end=500):
	"""Flatten a {question: (id, answer, {response: labels})} dict into a
	shuffled list of (id, question, answer, response, label) instances.

	Questions are ordered by id and only those with rank in [n_start, n_end)
	are included. Previously the batch window was chosen by editing
	hard-coded constants in place (batch 1: 0-100, batch 2: 100-300,
	batch 3: 300-400, batch 4: 400-500); the defaults preserve the old
	behavior (batch 4). Prints the number of instances produced.
	"""
	# Sort the questions by their id so the batch window is deterministic.
	sorted_qars = sorted(qars.items(), key=lambda kv: kv[1][0])
	all_instances = list()
	for i in range(n_start, n_end):
		q, (id, a, responses_dict) = sorted_qars[i]
		# Sort responses for a stable order before the random shuffle.
		sorted_responses_dict = sorted(responses_dict.items(), key=lambda kv: kv[0])
		for response, label in sorted_responses_dict:
			all_instances.append((id, q, a, response, label))
	random.shuffle(all_instances)
	print(len(all_instances))
	return all_instances

def save_qars_to_mturk_csv(qars, csv_save_file):
	"""Write annotation instances to an MTurk-style CSV, 10 instances per HIT row.

	Each row holds 10 (id, question, answer, response, label) instances
	flattened into 50 columns (id0..label0, id1..label1, ...). A short final
	batch is padded with instances taken from the top of the (already
	shuffled) list so every row has exactly 10 instances.
	"""
	n = 10  # instances per HIT row
	# newline="" is required by the csv module; without it extra blank rows
	# appear on platforms that translate line endings.
	with open(csv_save_file, "w", newline="") as csv_file:
		writer = csv.writer(csv_file, delimiter=',')
		# Header: id0,question0,answer0,response0,label0,id1,...,label9
		header = [field + str(i) for i in range(n) for field in ("id", "question", "answer", "response", "label")]
		writer.writerow(header)
		all_qars_list = convert_dict_to_array(qars)
		for i in range(0, len(all_qars_list), n):
			hit_qars = all_qars_list[i:i+n]
			if len(hit_qars) < n:
				# Fill the remaining slots from the top of the list.
				hit_qars.extend(all_qars_list[0:n-len(hit_qars)])
			# Flatten the n five-field tuples into a single 50-column row.
			writer.writerow([value for tup in hit_qars for value in tup])


# NOTE(review): only the LAST assignment takes effect — the first three batch
# filenames are dead stores. Presumably each earlier batch was produced by
# re-running with the later lines commented out, together with the matching
# batch window inside convert_dict_to_array; confirm before re-generating.
output_csv = "new_final_evaluation_batch1.csv"
output_csv = "new_final_evaluation_batch2.csv"
output_csv = "new_final_evaluation_batch3.csv"
output_csv = "new_final_evaluation_batch4.csv"
# Write the currently selected batch to CSV for MTurk annotation.
save_qars_to_mturk_csv(all_qar_dict, output_csv)
# read_opennmt_output(pgn_output)
# read_opennmt_output(pgn_and_pretraining_output)
# read_opennmt_output(pgn_and_pretraining_and_glove_output)
# read_opennmt_output(pgn_and_pretraining_and_glove_and_lm_output)
# read_opennmt_output(pgn_and_pretraining_and_glove_and_am_output)



